import torch
from torch.utils import data
import os


from shutil import copyfile
import os.path as osp
from PIL import Image
from PIL import ImageFile
from functools import partial

ImageFile.LOAD_TRUNCATED_IMAGES = True


def loader(path):
    """Open the image at *path* and return it converted to RGB.

    The source image is opened in a ``with`` block so that its underlying
    file handle is released deterministically instead of lingering until
    garbage collection (which Pillow may do for formats it keeps open after
    loading). ``convert`` returns a new, fully loaded image, so the result
    stays valid after the source is closed.
    """
    with Image.open(path) as source_image:
        return source_image.convert('RGB')

def this_loader(path,dataset_root, preload_dir):
    """Load an image through a local copy kept under *preload_dir*.

    The file at *path* is mirrored to
    ``preload_dir/<path relative to dataset_root>`` the first time it is
    requested; the image is then always decoded from that copy.

    Args:
        path: Path of the original image file.
        dataset_root: Directory that *path* is made relative to.
        preload_dir: Directory that receives the mirrored copies.

    Returns:
        The image returned by ``loader`` for the mirrored file.
    """
    dest_path = osp.join(preload_dir, osp.relpath(path, dataset_root))
    os.makedirs(osp.dirname(dest_path), exist_ok=True)

    if not osp.isfile(dest_path):
        # Copy to a process-private temporary name and rename it into place.
        # Copying straight to dest_path would let a concurrent worker see the
        # file via isfile() and decode it while it is still being written.
        tmp_path = '{}.{}.tmp'.format(dest_path, os.getpid())
        try:
            copyfile(path, tmp_path)
            os.replace(tmp_path, dest_path)
        finally:
            # Only present if the copy or the rename failed part-way.
            if osp.exists(tmp_path):
                os.remove(tmp_path)

    image = loader(dest_path)
    return image

def preloader(dataset_root, preload_dir):
    """Build a one-argument loader that mirrors files into *preload_dir*.

    The returned callable takes an image path and forwards it to
    ``this_loader`` with *dataset_root* and *preload_dir* already bound.
    """
    bound_kwargs = {'dataset_root': dataset_root, 'preload_dir': preload_dir}
    return partial(this_loader, **bound_kwargs)

class ADDS(data.Dataset):
    """
    Folder-per-class image dataset ("AD") with a cached file index.

    Expected layout: ``<root>/<split>/<class_name>/<image file>``.
    Class indices are assigned by sorting the class folder names.
    The index (classes, class-to-index map, relative image paths, labels and
    the split sub-directory) is cached with ``torch.save`` so that later
    constructions can skip the directory scan.
    """
    # NOTE(review): indexing may discover any number of class folders; confirm
    # how NUM_CLASSES is consumed before relying on it.
    NUM_CLASSES = 1
    # Per-channel normalisation constants (the widely used ImageNet statistics).
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]

    def __init__(self, root, split='train', transform=None, force_reindex=False,
                 preload=False, cache_dir='/private/home/ymasano/monov2/data',
                 preload_dir='/tmp/data'):
        """Index (or load the cached index of) one split of the dataset.

        Args:
            root: Dataset root directory containing the split folder.
            split: Name of the sub-directory of *root* to index.
            transform: Optional callable applied to each loaded image.
            force_reindex: If true, ignore an existing cache file and rescan.
            preload: If true, images are mirrored into *preload_dir* on first
                access and read from there.
            cache_dir: Directory holding the cached index file, or ``None`` to
                disable caching.
            preload_dir: Destination for mirrored images when *preload* is set.

        Raises:
            ValueError: If the split folder is missing or has no class folders.
        """
        self.root = root
        self.transform = transform
        self.split = split
        if cache_dir is None:
            cachefile = None
        else:
            cachefile = osp.join(
                cache_dir,
                'AD-' + root.replace('/', '-') + split + '-cached-list.pth',
            )
        (self.classes, self.class_to_idx, self.imgs, self.labels,
         self.images_subdir) = self.get_dataset(cachefile, force_reindex)
        self.loader = preloader(root, preload_dir) if preload else loader

    @staticmethod
    def _is_main_process():
        """Return True when not running distributed, or when on rank 0."""
        dist = torch.distributed
        # get_rank() raises without an initialised default process group, so
        # treat a plain single-process run as the main process.
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank() == 0
        return True

    def get_dataset(self, cachefile=None, force_reindex=False):
        """Return the dataset index, from cache if possible.

        Args:
            cachefile: Path of the cache file, or ``None`` to neither read nor
                write a cache.
            force_reindex: If true, rescan even when *cachefile* exists.

        Returns:
            Tuple ``(classes, class_to_idx, imgs, labels, images_subdir)``
            where ``imgs`` holds paths relative to the split folder and
            ``labels`` the matching class indices.

        Raises:
            ValueError: If the split folder is missing or has no class folders.
        """
        if cachefile is not None and not force_reindex and osp.isfile(cachefile):
            print('Loaded AD {} dataset from cache: {}...'.format(self.split, cachefile))
            # NOTE: torch.load unpickles the file; only point cachefile at
            # locations you trust.
            return torch.load(cachefile)

        print('Indexing AD {} dataset...'.format(self.split))
        images_subdir = self.split
        split_dir = osp.join(self.root, images_subdir)
        print(split_dir)
        if not osp.isdir(split_dir):
            raise ValueError('Split {} not found'.format(self.split))
        self.images_subdir = images_subdir

        # Only directories are classes; stray files next to the class folders
        # (README, .DS_Store, ...) are ignored instead of crashing listdir.
        classes = sorted(
            entry for entry in os.listdir(split_dir)
            if osp.isdir(osp.join(split_dir, entry))
        )
        if not classes:
            raise ValueError('No class folders found in {}'.format(split_dir))
        class_to_idx = {c: i for (i, c) in enumerate(classes)}

        imgs = []
        labels = []
        for label in classes:
            label_images = sorted(os.listdir(osp.join(split_dir, label)))
            imgs.extend(osp.join(label, name) for name in label_images)
            labels.extend([class_to_idx[label]] * len(label_images))
        print('OK!')

        returns = (classes, class_to_idx, imgs, labels, images_subdir)
        if cachefile is not None and self._is_main_process():
            print('rank: 0')
            cache_parent = osp.dirname(cachefile)
            if cache_parent:
                os.makedirs(cache_parent, exist_ok=True)
            # Write to a temporary name and rename, so other processes never
            # load a partially written cache.
            tmp_cachefile = '{}.tmp.{}'.format(cachefile, os.getpid())
            try:
                torch.save(returns, tmp_cachefile)
                os.replace(tmp_cachefile, cachefile)
            finally:
                if osp.exists(tmp_cachefile):
                    os.remove(tmp_cachefile)
            print(f' cached to {cachefile}')
        print()
        return returns

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample *idx*, applying the transform if set."""
        image = self.loader(osp.join(self.root, self.images_subdir, self.imgs[idx]))
        if self.transform is not None:
            image = self.transform(image)
        return (image, self.labels[idx])

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.imgs)

    def __repr__(self):
        return "AD(root='{}', split='{}')".format(self.root, self.split)